{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pygrid - MNIST Training (Domain-User)\n",
    "In this notebook, we'll be showing how to perform a Remote training with MNIST dataset using PyGrid/PySyft 0.3.0.\n",
    "\n",
    "**Author**: Ionésio Junior **Github**: [@IonesioJunior](https://github.com/IonesioJunior)    \n",
    "**Author**: Madhava Jay **Github**: [@madhavajay](https://github.com/madhavajay)    \n",
    "**Author**: Tudor Cebere **Github**: [@tudorcebere](https://github.com/tudorcebere)    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from syft.grid.grid_client import connect\n",
    "from syft.grid.connections.http_connection import HTTPConnection\n",
    "from syft.core.node.domain.client import DomainClient\n",
    "\n",
    "import syft as sy\n",
    "import torch as th\n",
    "import torchvision as tv\n",
    "\n",
    "sy.VERBOSE = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "client = connect(\n",
    "    url=\"http://localhost:5000\", # Domain Address\n",
    "    credentials= {\"email\": \"user@email.com\", \"password\": \"pwd123\"}, # Admin role\n",
    "    conn_type= HTTPConnection, # HTTP Connection Protocol\n",
    "    client_type=DomainClient) # Domain Client type"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3 - Retrieve some domain setting attributes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# lets get some references to our data owners Duet torch and torchvision\n",
    "torch = client.torch\n",
    "torchvision = client.torchvision\n",
    "\n",
    "# these are the same as the original mnist example\n",
    "transforms = torchvision.transforms\n",
    "datasets = torchvision.datasets\n",
    "nn = torch.nn\n",
    "F = torch.nn.functional\n",
    "optim = torch.optim\n",
    "StepLR = torch.optim.lr_scheduler.StepLR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create local transforms for our local MNIST data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we need some transforms for our MNIST data set\n",
    "local_transform_1 = tv.transforms.ToTensor()  # this converts PIL images to Tensors\n",
    "local_transform_2 = tv.transforms.Normalize(0.1307, 0.3081)  # this normalizes the dataset\n",
    "\n",
    "# compose our transforms\n",
    "local_transforms = tv.transforms.Compose([local_transform_1, local_transform_2])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define Training settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training settings from original MNIST example command line args\n",
    "args = {\n",
    "    \"batch_size\": 64,\n",
    "    \"test_batch_size\": 1000,\n",
    "    \"epochs\": 14,\n",
    "    \"lr\": 1.0,\n",
    "    \"gamma\": 0.7,\n",
    "    \"no_cuda\": False,\n",
    "    \"dry_run\": False,\n",
    "    \"seed\": 42, # the meaning of life\n",
    "    \"log_interval\": 10,\n",
    "    \"save_model\": False,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define our test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_kwargs = {\n",
    "    \"batch_size\": args[\"test_batch_size\"],\n",
    "}\n",
    "\n",
    "# this is our carefully curated test data which represents the goal of our problem domain\n",
    "test_data = tv.datasets.MNIST('../data', train=False, download=True, transform=local_transforms)\n",
    "test_loader = th.utils.data.DataLoader(test_data,**test_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data_length = len(test_loader.dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Local / Remote Model Architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# lets define our SOTA model to train on the Domain's data\n",
    "# note we subclass from sy.Module not nn.Module\n",
    "fc1_scaling_factor = 0.25  # this can let us scale the fc1 layer down a bit\n",
    "class SyNet(sy.Module):\n",
    "    def __init__(self):\n",
    "        super(SyNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
    "        self.conv2 = nn.Conv2d(32, int(64 * fc1_scaling_factor), 3, 1)  # keep fc1 size down\n",
    "        self.dropout1 = nn.Dropout2d(0.25)\n",
    "        self.dropout2 = nn.Dropout2d(0.5)\n",
    "        self.fc1 = nn.Linear(int(9216 * fc1_scaling_factor), 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, 2)\n",
    "        x = self.dropout1(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.dropout2(x)\n",
    "        x = self.fc2(x)\n",
    "        output = F.log_softmax(x, dim=1)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# note we subclass from sy.Module not nn.Module\n",
    "# WARNING: be extra careful to use th. not the torch from domain here\n",
    "class LocalSyNet(sy.Module):\n",
    "    def __init__(self):\n",
    "        super(LocalSyNet, self).__init__()\n",
    "        self.conv1 = th.nn.Conv2d(1, 32, 3, 1)\n",
    "        self.conv2 = th.nn.Conv2d(32, int(64 * fc1_scaling_factor), 3, 1)\n",
    "        self.dropout1 = th.nn.Dropout2d(0.25)\n",
    "        self.dropout2 = th.nn.Dropout2d(0.5)\n",
    "        self.fc1 = th.nn.Linear(int(9216 * fc1_scaling_factor), 128)\n",
    "        self.fc2 = th.nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = th.nn.functional.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = th.nn.functional.relu(x)\n",
    "        x = th.nn.functional.max_pool2d(x, 2)\n",
    "        x = self.dropout1(x)\n",
    "        x = th.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = th.nn.functional.relu(x)\n",
    "        x = self.dropout2(x)\n",
    "        x = self.fc2(x)\n",
    "        output = th.nn.functional.log_softmax(x, dim=1)\n",
    "        return output\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "> Waiting for Blocking Request\n",
      "cuda_is_available: To run test and inference locally.\n",
      "<UID:8f4ac16b-fb65-48c8-883b-67a12a226cec>\n",
      "\n",
      "> INSIDE Request BLOCK 2.1457672119140625e-06 seconds False\n",
      "\n",
      "> Blocking Request ACCEPTED\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "# lets see if our Domain has CUDA\n",
    "has_cuda = False\n",
    "has_cuda_ptr = torch.cuda.is_available()\n",
    "has_cuda = bool(has_cuda_ptr.get(\n",
    "    request_block=True,\n",
    "    request_name=\"cuda_is_available\",\n",
    "    reason=\"To run test and inference locally\",\n",
    "    timeout_secs=30,  # change to something slower\n",
    "))\n",
    "print(has_cuda)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DO device is cpu\n"
     ]
    }
   ],
   "source": [
    "use_cuda = not args[\"no_cuda\"] and has_cuda\n",
    "torch.manual_seed(args[\"seed\"])\n",
    "\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "print(f\"DO device is {device.type.get()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6\n",
      "OrderedDict([('conv1', <syft.proxy.torch.nn.Conv2dPointer object at 0x7fbd34d599a0>), ('conv2', <syft.proxy.torch.nn.Conv2dPointer object at 0x7fbdc1e952b0>), ('dropout1', <syft.proxy.torch.nn.Dropout2dPointer object at 0x7fbd34d59fd0>), ('dropout2', <syft.proxy.torch.nn.Dropout2dPointer object at 0x7fbd34d59c40>), ('fc1', <syft.proxy.torch.nn.LinearPointer object at 0x7fbd34d6d430>), ('fc2', <syft.proxy.torch.nn.LinearPointer object at 0x7fbd34d6d910>)])\n"
     ]
    }
   ],
   "source": [
    "model = SyNet()\n",
    "print(len(model.modules))\n",
    "print(model.modules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Layer conv1 sum(weight): 0.0\n",
      "> Layer conv1 sum(bias): 0.0\n",
      "> Layer conv2 sum(weight): 0.0\n",
      "> Layer conv2 sum(bias): 0.0\n",
      "> Layer fc1 sum(weight): 0.0\n",
      "> Layer fc1 sum(bias): 0.0\n",
      "> Layer fc2 sum(weight): 0.0\n",
      "> Layer fc2 sum(bias): 0.0\n",
      "OrderedDict([('conv1', Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))), ('conv2', Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1))), ('dropout1', Dropout2d(p=0.25, inplace=False)), ('dropout2', Dropout2d(p=0.5, inplace=False)), ('fc1', Linear(in_features=2304, out_features=128, bias=True)), ('fc2', Linear(in_features=128, out_features=10, bias=True))])\n"
     ]
    }
   ],
   "source": [
    "local_model = LocalSyNet()\n",
    "local_model.zero_layers()  # so we can confirm that the weight download works\n",
    "local_model.debug_sum_layers()\n",
    "print(local_model.modules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if we have CUDA lets send our model to the GPU\n",
    "if has_cuda:\n",
    "    model.cuda(device)\n",
    "else:\n",
    "    model.cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<syft.proxy.syft.lib.python.ListPointer object at 0x7fbd34d59d60> <class 'syft.proxy.syft.lib.python.ListPointer'>\n"
     ]
    }
   ],
   "source": [
    "# lets get our parameters for optimization\n",
    "# params_list required for remote list concatenation\n",
    "params = model.parameters(params_list=client.syft.lib.python.List())\n",
    "print(params, type(params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<syft.proxy.torch.optim.AdadeltaPointer object at 0x7fbd34d59400> <class 'syft.proxy.torch.optim.AdadeltaPointer'>\n"
     ]
    }
   ],
   "source": [
    "optimizer = optim.Adadelta(params, lr=args[\"lr\"])\n",
    "print(optimizer, type(optimizer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<syft.proxy.torch.optim.lr_scheduler.StepLRPointer object at 0x7fbd34d6d400> <class 'syft.proxy.torch.optim.lr_scheduler.StepLRPointer'>\n"
     ]
    }
   ],
   "source": [
    "scheduler = StepLR(optimizer, step_size=1, gamma=args[\"gamma\"])\n",
    "print(scheduler, type(scheduler))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define / Execute Remote and Local Training routines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now can define a simple training loop very similar to the original PyTorch MNIST example\n",
    "@sy.logger.catch\n",
    "def train(args, model, device, train_loader, optimizer, epoch, train_data_length):\n",
    "    # + 0.5 lets us math.ceil without the import\n",
    "    train_batches = round((train_data_length / args[\"batch_size\"]) + 0.5)\n",
    "#     train_batches = 100\n",
    "    print(f\"> Running train in {train_batches} batches\")\n",
    "    model.train()\n",
    "\n",
    "    for batch_idx, data in enumerate(train_loader):\n",
    "#         time.sleep(1)\n",
    "        data_ptr, target_ptr = data[0], data[1]\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data_ptr)\n",
    "        loss = F.nll_loss(output, target_ptr)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        loss_item = loss.item()\n",
    "        train_loss = client.syft.lib.python.Float(0)\n",
    "        train_loss += loss_item\n",
    "        if batch_idx % args[\"log_interval\"] == 0:\n",
    "#             local_loss = loss_item.get(\n",
    "#                 request_name=\"loss\",\n",
    "#                 reason=\"To evaluate training progress\",\n",
    "#                 request_block=True,\n",
    "#                 timeout_secs=30\n",
    "#             )\n",
    "            local_loss = None\n",
    "            if local_loss is not None:\n",
    "                print(\"Train Epoch: {} {} {:.4}\".format(epoch, batch_idx, local_loss))\n",
    "            else:\n",
    "                print(\"Train Epoch: {} {} ?\".format(epoch, batch_idx))\n",
    "            if args[\"dry_run\"]:\n",
    "                break\n",
    "        if batch_idx >= train_batches - 1:\n",
    "            print(\"batch_idx >= train_batches, breaking\")\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO replace with local inference and local test set\n",
    "# the same for our test training loop except we will need to send our data over for inference\n",
    "@sy.logger.catch\n",
    "def test_local(model, remote_model, test_loader, test_data_length):\n",
    "    # download remote model\n",
    "    model.copy_remote_state(\n",
    "        remote_model=remote_model,\n",
    "        request_name=\"model_download\",\n",
    "        reason=\"test evaluation\",\n",
    "        timeout_secs=30\n",
    "    )\n",
    "    # visually check the weights have changed\n",
    "    model.debug_sum_layers()\n",
    "    # + 0.5 lets us math.ceil without the import\n",
    "    test_batches = round((test_data_length / args[\"test_batch_size\"]) + 0.5)\n",
    "    print(f\"> Running test_local in {test_batches} batches\")\n",
    "    model.eval()\n",
    "    test_loss = 0.0\n",
    "    correct = 0.0\n",
    "\n",
    "    with th.no_grad():\n",
    "        for batch_idx, (data, target) in enumerate(test_loader):\n",
    "#             time.sleep(1)\n",
    "            output = model(data)\n",
    "            iter_loss = th.nn.functional.nll_loss(output, target, reduction='sum').item()\n",
    "            test_loss = test_loss + iter_loss\n",
    "            pred = output.argmax(dim=1)\n",
    "            total = pred.eq(target).sum().item()\n",
    "            correct += total\n",
    "            if args[\"dry_run\"]:\n",
    "                break\n",
    "                \n",
    "            if batch_idx >= test_batches - 1:\n",
    "                print(\"batch_idx >= test_batches, breaking\")\n",
    "                break\n",
    "\n",
    "    accuracy = correct / test_data_length\n",
    "    print(\"Test Set Average Loss:\", 100 * accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'syft.proxy.torchvision.transforms.ToTensorPointer'> <class 'syft.proxy.torchvision.transforms.NormalizePointer'>\n",
      "<syft.proxy.torchvision.datasets.MNISTPointer object at 0x7fbd34d59160>\n",
      "<syft.proxy.torch.utils.data.DataLoaderPointer object at 0x7fbd34d59f10>\n"
     ]
    }
   ],
   "source": [
    "# we need some transforms for our MNIST data set\n",
    "transform_1 = torchvision.transforms.ToTensor()  # this converts PIL images to Tensors\n",
    "transform_2 = torchvision.transforms.Normalize(0.1307, 0.3081)  # this normalizes the dataset\n",
    "print(type(transform_1), type(transform_2))\n",
    "\n",
    "remote_list = client.syft.lib.python.List()\n",
    "remote_list.append(transform_1)\n",
    "remote_list.append(transform_2)\n",
    "\n",
    "# compose our transforms\n",
    "transforms = torchvision.transforms.Compose(remote_list)\n",
    "\n",
    "# The DO has kindly let us initialise a DataLoader for their training set\n",
    "train_kwargs = {\n",
    "    \"batch_size\": args[\"batch_size\"],\n",
    "}\n",
    "train_data_ptr = torchvision.datasets.MNIST('../data', train=True, download=True, transform=transforms)\n",
    "print(train_data_ptr)\n",
    "train_loader_ptr = torch.utils.data.DataLoader(train_data_ptr,**train_kwargs)\n",
    "print(train_loader_ptr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "> Waiting for Blocking Request\n",
      "train_size: To write the training loop.\n",
      "<UID:1c6f7ab9-665a-4afb-a039-e127d62447ef>\n",
      "\n",
      "> INSIDE Request BLOCK 7.152557373046875e-07 seconds False\n",
      "\n",
      "> Blocking Request ACCEPTED\n",
      "Training Dataset size is: 60000\n"
     ]
    }
   ],
   "source": [
    "def get_train_length(train_data_ptr):\n",
    "    train_length_ptr = train_data_ptr.__len__()\n",
    "    train_data_length = train_length_ptr.get(\n",
    "        request_block=True,\n",
    "        request_name=\"train_size\",\n",
    "        reason=\"To write the training loop\",\n",
    "        timeout_secs=30,\n",
    "    )\n",
    "    return train_data_length\n",
    "\n",
    "try:\n",
    "    if train_data_length is None:\n",
    "        train_data_length = get_train_length(train_data_ptr)\n",
    "except NameError:\n",
    "        train_data_length = get_train_length(train_data_ptr)\n",
    "\n",
    "print(f\"Training Dataset size is: {train_data_length}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1\n",
      "> Running train in 938 batches\n",
      "Train Epoch: 1 0 ?\n",
      "Train Epoch: 1 10 ?\n",
      "Train Epoch: 1 20 ?\n",
      "Train Epoch: 1 30 ?\n",
      "Train Epoch: 1 40 ?\n",
      "Train Epoch: 1 50 ?\n",
      "Train Epoch: 1 60 ?\n",
      "Train Epoch: 1 70 ?\n",
      "Train Epoch: 1 80 ?\n",
      "Train Epoch: 1 90 ?\n",
      "Train Epoch: 1 100 ?\n",
      "Train Epoch: 1 110 ?\n",
      "Train Epoch: 1 120 ?\n",
      "Train Epoch: 1 130 ?\n",
      "Train Epoch: 1 140 ?\n",
      "Train Epoch: 1 150 ?\n",
      "Train Epoch: 1 160 ?\n",
      "Train Epoch: 1 170 ?\n",
      "Train Epoch: 1 180 ?\n",
      "Train Epoch: 1 190 ?\n",
      "Train Epoch: 1 200 ?\n",
      "Train Epoch: 1 210 ?\n",
      "Train Epoch: 1 220 ?\n",
      "Train Epoch: 1 230 ?\n",
      "Train Epoch: 1 240 ?\n",
      "Train Epoch: 1 250 ?\n",
      "Train Epoch: 1 260 ?\n",
      "Train Epoch: 1 270 ?\n",
      "Train Epoch: 1 280 ?\n",
      "Train Epoch: 1 290 ?\n",
      "Train Epoch: 1 300 ?\n",
      "Train Epoch: 1 310 ?\n",
      "Train Epoch: 1 320 ?\n",
      "Train Epoch: 1 330 ?\n",
      "Train Epoch: 1 340 ?\n",
      "Train Epoch: 1 350 ?\n",
      "Train Epoch: 1 360 ?\n",
      "Train Epoch: 1 370 ?\n",
      "Train Epoch: 1 380 ?\n",
      "Train Epoch: 1 390 ?\n",
      "Train Epoch: 1 400 ?\n",
      "Train Epoch: 1 410 ?\n",
      "Train Epoch: 1 420 ?\n",
      "Train Epoch: 1 430 ?\n",
      "Train Epoch: 1 440 ?\n",
      "Train Epoch: 1 450 ?\n",
      "Train Epoch: 1 460 ?\n",
      "Train Epoch: 1 470 ?\n",
      "Train Epoch: 1 480 ?\n",
      "Train Epoch: 1 490 ?\n",
      "Train Epoch: 1 500 ?\n",
      "Train Epoch: 1 510 ?\n",
      "Train Epoch: 1 520 ?\n",
      "Train Epoch: 1 530 ?\n",
      "Train Epoch: 1 540 ?\n",
      "Train Epoch: 1 550 ?\n",
      "Train Epoch: 1 560 ?\n",
      "Train Epoch: 1 570 ?\n",
      "Train Epoch: 1 580 ?\n",
      "Train Epoch: 1 590 ?\n",
      "Train Epoch: 1 600 ?\n",
      "Train Epoch: 1 610 ?\n",
      "Train Epoch: 1 620 ?\n",
      "Train Epoch: 1 630 ?\n",
      "Train Epoch: 1 640 ?\n",
      "Train Epoch: 1 650 ?\n",
      "Train Epoch: 1 660 ?\n",
      "Train Epoch: 1 670 ?\n",
      "Train Epoch: 1 680 ?\n",
      "Train Epoch: 1 690 ?\n",
      "Train Epoch: 1 700 ?\n",
      "Train Epoch: 1 710 ?\n",
      "Train Epoch: 1 720 ?\n",
      "Train Epoch: 1 730 ?\n",
      "Train Epoch: 1 740 ?\n",
      "Train Epoch: 1 750 ?\n",
      "Train Epoch: 1 760 ?\n",
      "Train Epoch: 1 770 ?\n",
      "Train Epoch: 1 780 ?\n",
      "Train Epoch: 1 790 ?\n",
      "Train Epoch: 1 800 ?\n",
      "Train Epoch: 1 810 ?\n",
      "Train Epoch: 1 820 ?\n",
      "Train Epoch: 1 830 ?\n",
      "Train Epoch: 1 840 ?\n",
      "Train Epoch: 1 850 ?\n",
      "Train Epoch: 1 860 ?\n",
      "Train Epoch: 1 870 ?\n",
      "Train Epoch: 1 880 ?\n",
      "Train Epoch: 1 890 ?\n",
      "Train Epoch: 1 900 ?\n",
      "Train Epoch: 1 910 ?\n",
      "Train Epoch: 1 920 ?\n",
      "Train Epoch: 1 930 ?\n",
      "batch_idx >= train_batches, breaking\n",
      "> Downloading remote: conv1\n",
      "\n",
      "> Waiting for Blocking Request\n",
      "model_download: test evaluation.\n",
      "<UID:a7f7f397-f830-42d6-98f0-022db1d46eea>\n",
      "\n",
      "> INSIDE Request BLOCK 4.76837158203125e-07 seconds False\n",
      "\n",
      "> Blocking Request ACCEPTED\n",
      ">> Setting weight copy on local conv1\n",
      ">> Setting bias copy on local conv1\n",
      "> Finished downloading model\n",
      "> Downloading remote: conv2\n",
      "\n",
      "> Waiting for Blocking Request\n",
      "model_download: test evaluation.\n",
      "<UID:2bd8d53a-1166-425b-a304-4b26a7bff331>\n",
      "\n",
      "> INSIDE Request BLOCK 2.384185791015625e-06 seconds False\n",
      "\n",
      "> Blocking Request ACCEPTED\n",
      ">> Setting weight copy on local conv2\n",
      ">> Setting bias copy on local conv2\n",
      "> Finished downloading model\n",
      "> Finished downloading model\n",
      "> Finished downloading model\n",
      "> Downloading remote: fc1\n",
      "\n",
      "> Waiting for Blocking Request\n",
      "model_download: test evaluation.\n",
      "<UID:eb0f0d09-0096-4c6f-a3fa-2839d56d3881>\n",
      "\n",
      "> INSIDE Request BLOCK 1.430511474609375e-06 seconds False\n",
      "\n",
      "> Blocking Request ACCEPTED\n",
      ">> Setting weight copy on local fc1\n",
      ">> Setting bias copy on local fc1\n",
      "> Finished downloading model\n",
      "> Downloading remote: fc2\n",
      "\n",
      "> Waiting for Blocking Request\n",
      "model_download: test evaluation.\n",
      "<UID:94d37508-1308-4f95-b2d9-14c129943640>\n",
      "\n",
      "> INSIDE Request BLOCK 4.76837158203125e-07 seconds False\n",
      "\n",
      "> Blocking Request ACCEPTED\n",
      ">> Setting weight copy on local fc2\n",
      ">> Setting bias copy on local fc2\n",
      "> Finished downloading model\n",
      "> Layer conv1 sum(weight): -10.614295959472656\n",
      "> Layer conv1 sum(bias): -1.6140550374984741\n",
      "> Layer conv2 sum(weight): -85.56488037109375\n",
      "> Layer conv2 sum(bias): -0.29050740599632263\n",
      "> Layer fc1 sum(weight): -192.95501708984375\n",
      "> Layer fc1 sum(bias): -2.1897318363189697\n",
      "> Layer fc2 sum(weight): -69.46456909179688\n",
      "> Layer fc2 sum(bias): -0.1939811259508133\n",
      "> Running test_local in 10 batches\n",
      "batch_idx >= test_batches, breaking\n",
      "Test Set Average Loss: 98.17\n",
      "Epoch time: 1019 seconds\n",
      "CPU times: user 6min 1s, sys: 11.6 s, total: 6min 13s\n",
      "Wall time: 16min 59s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "import time\n",
    "\n",
    "args[\"dry_run\"] = False\n",
    "sy.logger.trace(\"Start Training\")\n",
    "for epoch in range(1, args[\"epochs\"] + 1):\n",
    "    epoch_start = time.time()\n",
    "    print(f\"Epoch: {epoch}\")\n",
    "    train(args, model, device, train_loader_ptr, optimizer, epoch, train_data_length)\n",
    "    test_local(local_model, model, test_loader, test_data_length)  # real local data and model\n",
    "    scheduler.step()\n",
    "    epoch_end = time.time()\n",
    "    print(f\"Epoch time: {int(epoch_end - epoch_start)} seconds\")\n",
    "    break\n",
    "sy.logger.trace(\"Finish Training\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running Inferences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "def draw_image_and_label(image, label):\n",
    "    fig = plt.figure()\n",
    "    plt.tight_layout()\n",
    "    plt.imshow(image, cmap=\"gray\", interpolation=\"none\")\n",
    "    plt.title(\"Ground Truth: {}\".format(label))\n",
    "    \n",
    "def prep_for_inference(image):\n",
    "    image_batch = image.unsqueeze(0).unsqueeze(0)\n",
    "    image_batch = image_batch * 1.0\n",
    "    return image_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classify_local(image):\n",
    "    image_tensor = th.Tensor(prep_for_inference(image))\n",
    "    print(\"1\", type(image_tensor))\n",
    "    output = local_model(image_tensor)\n",
    "    print(\"2\", type(output))\n",
    "    preds = th.exp(output)\n",
    "    print(\"3\", type(preds))\n",
    "    local_y = preds\n",
    "    local_y = local_y.squeeze()\n",
    "    pos = local_y == max(local_y)\n",
    "    index = th.nonzero(pos, as_tuple=False)\n",
    "    class_num = index.squeeze()\n",
    "    return class_num, local_y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classify(image):\n",
    "    image_tensor_ptr = torch.Tensor(prep_for_inference(image))\n",
    "    output = model(image_tensor_ptr)\n",
    "    preds = torch.exp(output)\n",
    "    preds_result = preds.get(\n",
    "        request_block=True,\n",
    "        request_name=\"inference\",\n",
    "        reason=\"To see a real world example of inference\",\n",
    "        timeout_secs=10\n",
    "    )\n",
    "    if preds_result is None:\n",
    "        print(\"No permission to do inference, request again\")\n",
    "        return -1, th.Tensor([-1])\n",
    "    else:\n",
    "        local_y = th.Tensor(preds_result)\n",
    "        local_y = local_y.squeeze()\n",
    "        pos = local_y == max(local_y)\n",
    "        index = th.nonzero(pos, as_tuple=False)\n",
    "        class_num = index.squeeze()\n",
    "        return class_num, local_y\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random Test Image: 5011\n",
      "Displaying 5011 == 11 in Batch: 5/10\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAARoUlEQVR4nO3de7BV9XnG8e8j8ZKIE1EiImKIlzpNmoqKWoPNmNEoQa2kkyHatNUm9hjRtNHgZeKMWjVTR811pjrFkYrVGpMqoCRIDE20NB0VqCBqopRiBLl4qUVE5fb2j72O2eJZv33OvnN+z2fmzNlnvXut9Z7NeVhrr8v+KSIws8Fvl043YGbt4bCbZcJhN8uEw26WCYfdLBMOu1kmHHYDQNIYSSHpAx1Y90pJJ7d7vblx2NtI0lmSHpP0pqT1xeMpktTp3lIkbaz62i7praqfvzTAZd0h6fom9/c1Sf8jaYOkhZJOaObyBwuHvU0kfQP4PnATsD8wAvgqMB7YrWSeIW1rMCEihvZ+Ab8Fzqiadnfv8zq0V3AccAPwBeDDwO3AzG557bqJw94Gkj4MXAtMiYh/jYg3ouK/IuJLEfFO8bw7JN0q6aeS3gQ+I+n3Jf1S0uuSnpb0J1XL/aWk86p+PlfSgqqfQ9JXJT1fzP8PvXsRkoZIulnSK5JWAKfV8XudKGmVpMslrQX+acceqvo4VFIP8CXgsmKv4MGqp42VtFTS/0m6V9Ie/WxjDPB0RCyKyuWgdwLDgf0G+vsMdg57exwP7A7M7sdz/wz4FrAX8BjwIPAzKn+8XwPulnT4ANZ9OnAM8IfAZODUYvpfF7UjgXFUtoz12B/YB/go0JN6YkRMA+4Gbiz2Cs6oKk8GJgAfK3o9t7dQ/EdVtms+Fxgi6bhia/5l4ElgbV2/zSDmsLfHcOCViNjaO0HSr4o/4rckfbrqubMj4j8iYjswFhgK3BARmyPi34A5wNkDWPcNEfF6RPwW+EWxTKiE63sR8WJEvAb8fZ2/23bg6oh4JyLeqnMZAD+IiJeKXh6s6pOI2DsiFpTM9wZwH7AAeAe4GugJ3/TxPg57e7wKDK9+TxsRn4qIvYta9b/Di1WPDwBeLILf6wVg1ADWXb2F20TlP493l73DcuvxckS8Xee81cr6rOUrwF8Bn6By7OPPgTmSDmhCT4OKw94e/0llq3NmP55bvUV6CRgtqfrf6SBgdfH4TeBDVbX9B9DTGmD0Dsutx45b0Pf0JGnHnpq9xR0LzImI5yJie0Q8ROV3+1ST17PTc9jbICJeB/4OuEXSFyTtJWkXSWOBPROzPkZlK3eZpF0lnQicAfywqD8J/KmkD0k6lMpWrr9+BPyNpAMlDQOuGMC8KUuAT0gaWxxku2aH+jrg4CatC+AJ4DRJB6vis8DvAcuauI5BwWFvk4i4EbgEuIzKH/w64B+By4FflcyzmUq4Pwe8AtwC/GVE/Lp4yneBzcWyZlA5+NVftwHzqIRzMXD/wH6jvkXEc1TOPPwceJ7Ke+lqtwMfL45XzOrPMosj939cUr6Tyn9+vwQ2AD8Azq96jawgH8cwy4O37GaZcNjNMuGwm2XCYTfLRFtvXJDko4FmLRYRfd5F2dCWXdIESb+RtFxSs87TmlkL1H3qrbjp4Dngs8AqKhc3nB0RzyTm8ZbdrMVasWU/FlgeESuKiz9+SP8uBzWzDmgk7KN4740Uq+jjBg1JPcWnhyxsYF1m1qCWH6Ar7mGeBt6NN+ukRrbsq3nvXVMH8ru7scysyzQS9ieAwyR9TNJuwFnAA81py8yare7d+IjYKukiKndODQGmR8TTTevMzJqqrXe9+T27Weu15KIaM9t5OOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y0Rbh2w2G4iTTjopWZ8zZ06yPnXq1NLaLbfckpy3nZ+63C7esptlwmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmfAorrbTqnUe/ic/+Ulp7bjjjkvOu2TJkrp66gZlo7g2dFGNpJXAG8A2YGtEjGtkeWbWOs24gu4zEfFKE5ZjZi3k9+xmmWg07AH8TNIiST19PUFSj6SFkhY2uC4za0Cju/EnRMRqSfsBD0v6dUQ8Wv2EiJgGTAMfoDPrpIa27BGxuvi+HpgJHNuMpsys+eoOu6Q9Je3V+xg4BVjWrMbMrLka2Y0fAcyU1Lucf4mIh5rSlVk/zJ8/P1l/6KHyP8djjjkmOe/OfJ69TN1hj4gVwBFN7MXMWsin3swy4bCbZcJhN8uEw26WCYfdLBP+KOnMffCDH0zW99hjj2R927ZtyfqGDRsG3FOzvPrqq6W10aNHt7GT7uAtu1kmHHazTDjsZplw2M0y4bCbZcJhN8uEw26WCZ9nH+Q+8pGPJOup20ABjjrqqGR948aNyfr06dNLaz/+8Y+T8y5evDhZ33vvvZP1yZMnl9amTJmSnHcw8pbdLBMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8uEh2weBFLn0ufOnZuc9+ijj25o3S+//HKynrpffujQocl5H3/88WR91apVyfqxx5aPWXL44Ycn5920aVOy3s3Khmz2lt0sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TvZx8Err322tJarfPoL730UrJ+3nnnJeuPPPJIsp66BuDCCy9MznvppZcm66nz6ACXXHJJaW1nPo9er5pbdknTJa2XtKxq2j6SHpb0fPF9WGvbNLNG9Wc3/g5gwg7TrgDmR8RhwPziZzPrYjXDHhGPAq/tMPlMYEbxeAYwqbltmVmz1fuefURErCkerwVGlD1RUg/QU+d6zKxJGj5AFxGRusElIqYB08A3wph1Ur2n3tZJGglQfF/fvJbMrBXqDfsDwDnF43OA2c1px8xapeb97JLuAU4EhgPrgKuBWcCPgIOAF4DJEbHjQby+luXd+DqMHz8+WU999nut8dFPPfXUZH3ZsmXJei2HHHJIaW3evHnJeYcNS5/RnTp1arJ+1113lda2bNmSnHdnVnY/e8337BFxdknppIY6MrO28uWyZplw2M0y4bCbZcJhN8uEw26WCd/iuhOodZtq6iOZFyxYkJx37dq1yfruu++erNe6DfWiiy4qrQ0ZMiQ571lnnZWsP/zww8m6vZe37GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJjxk807g0EMPTdaXLl1aWksNmQywYsWKupcNMGnSpGR91qxZpbWLL744Oe/KlSuTdeubh2w2y5zDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLh8+yDwCmnnFJau+6665Lz1hr2uJaZM2cm66l70jdv3tzQuq1vPs9uljmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCnxs/CGzatKm0Nnr06Jaue8yYMcn6Lrt4e9Itav5LSJouab2kZVXTrpG0WtKTxdfE1rZpZo3qz3+7dwAT+pj+3YgYW3z9tLltmVmz1Qx7RDwKvNaGXsyshRp5Q3WRpKXFbv6wsidJ6pG0UNLCBtZlZg2qN+y3AocAY4E1wLfLnhgR0yJiXESMq3NdZtYEdYU9ItZFxLaI2A7cBjR265SZtVxdYZc0surHzwPLyp5rZt2h5nl2SfcAJwLDJa0CrgZOlDQWCGAlcH7rWrSTTz45Wb/33ntLa++8805y3ilTpiTrV111VbJ+5JFHJusHHnhgaW358uXJea25aoY9Is7uY/LtLejFzFrIlzeZZcJhN8uEw26WCYfdLBMOu1kmfItrFxg/fnyyfuuttybrb7/9dmmt1qm12bNnJ+sXXHBBsr7//vsn69Y9vGU3y4TDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLh8+xtMHLkyGT9pptuStYPOuigZP2LX/xiaa3WefT99tsvWT/ggAOS9Y0bNybrqWsArL28ZTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuHz7G1w+umnJ+vHH398sn799dcn67NmzRpoS++aOnVqsr7vvvsm67V6W7Vq1YB7stbwlt0sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4QiIv0EaTRwJzCCyhDN0yLi+5L2Ae4FxlAZtnlyRPxvjWWlVzZIPfjgg8n6brvtlqyfdtppyfrWrVtLaxMmTEjOe9dddyXrtc6zT5w4MVmfO3dusm7NFxHqa3p/tuxbgW9ExMeBPwIulPRx4ApgfkQcBswvfjazLlUz7BGxJiIWF4/fAJ4FRgFnAjOKp80AJrWoRzNrggG9Z5c0BjgSeAwYERFritJaKrv5Ztal+n1tvKShwH3A1yNig/S7twUREWXvxyX1AD2NNmpmjenXll3SrlSCfndE3F9MXidpZFEfCazva96ImBYR4yJiXDMaNrP61Ay7Kpvw24FnI+I7VaUHgHOKx+cA6Y8xNbOO6s9u/HjgL4CnJD1ZTPsmcAPwI0lfAV4AJrekw53ABz6QfhlHjRqVrK9duzZZr3X67IgjjiitXXXVVcl5a532q/Ux1/PmzUvWrXvUDHtELAD6PG8HnNTcdsysVXwFnVkmHHazTDjsZplw2M0y4bCbZcJhN8tEzVtcm7qyQXqL66677pqsL1q0KFn/5Cc/2cx2BuTmm29O1i+//PJkffv27c1sx5qgkVtczWwQcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJjxkcxNs2bIlWa81pHKt8+y1hj1eunRpae3KK69MzrtkyZJkvZ3XYVhrectulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XC97ObDTK+n90scw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y0TNsEsaLekXkp6R9LSkvy2mXyNptaQni6+JrW/XzOpV86IaSSOBkRGxWNJewCJgEjAZ2BgR6VEG3rssX1Rj1mJlF9XU/KSaiFgDrCkevyHpWWBUc9szs1Yb0Ht2SWOAI4HHikkXSVoqabqkYSXz9EhaKGlhY62aWSP6fW28pKHAI8C3IuJ+SSOAV4AArqOyq//lGsvwbrxZi5Xtxvcr7JJ2BeYA8yLiO33UxwBzIuIPaizHYTdrsbpvhJEk4Hbg2eqgFwfuen0eWNZok2bWOv05Gn8C8O/AU0Dv+LzfBM4GxlLZjV8JnF8czEsty1t2sxZraDe+WRx2s9bz/exmmXPYzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEzU/cLLJXgFeqPp5eDGtG3Vrb93aF7i3ejWzt4+WFdp6P/v7Vi4tjIhxHWsgoVt769a+wL3Vq129eTfeLBMOu1kmOh32aR1ef0q39tatfYF7q1dbeuvoe3Yza59Ob9nNrE0cdrNMdCTskiZI+o2k5ZKu6EQPZSStlPRUMQx1R8enK8bQWy9pWdW0fSQ9LOn54nufY+x1qLeuGMY7Mcx4R1+7Tg9/3vb37JKGAM8BnwVWAU8AZ0fEM21tpISklcC4iOj4BRiSPg1sBO7sHVpL0o3AaxFxQ/Ef5bCIuLxLeruGAQ7j3aLeyoYZP5cOvnbNHP68Hp3Ysh8LLI+IFRGxGfghcGYH+uh6EfEo8NoOk88EZhSPZ1D5Y2m7kt66QkSsiYjFxeM3gN5hxjv62iX6aotOhH0U8GLVz6vorvHeA/iZpEWSejrdTB9GVA2ztRYY0clm+lBzGO922mGY8a557eoZ/rxRPkD3fidExFHA54ALi93VrhSV92DddO70VuAQKmMArgG+3clmimHG7wO+HhEbqmudfO366Kstr1snwr4aGF3184HFtK4QEauL7+uBmVTednSTdb0j6Bbf13e4n3dFxLqI2BYR24Hb6OBrVwwzfh9wd0TcX0zu+GvXV1/tet06EfYngMMkfUzSbsBZwAMd6ON9JO1ZHDhB0p7AKXTfUNQPAOcUj88BZnewl/folmG8y4YZp8OvXceHP4+Itn8BE6kckf9v4MpO9FDS18HAkuLr6U73BtxDZbduC5VjG18B9gXmA88DPwf26aLe/pnK0N5LqQRrZId6O4HKLvpS4Mnia2KnX7tEX2153Xy5rFkmfIDOLBMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8vE/wNj2ab60p1hlQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# lets grab something from the test set\n",
    "import random\n",
    "total_images = test_data_length # 10000\n",
    "index = random.randint(0, total_images)\n",
    "print(\"Random Test Image:\", index)\n",
    "count = 0\n",
    "batch = index // test_kwargs[\"batch_size\"]\n",
    "batch_index = index % int(total_images / len(test_loader))\n",
    "for tensor_ptr in test_loader:\n",
    "    data, target = tensor_ptr[0], tensor_ptr[1]\n",
    "    if batch == count:\n",
    "        break\n",
    "    count += 1\n",
    "\n",
    "print(f\"Displaying {index} == {batch_index} in Batch: {batch}/{len(test_loader)}\")\n",
    "image_1 = data[batch_index].reshape((28, 28))\n",
    "label_1 = target[batch_index]\n",
    "draw_image_and_label(image_1, label_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Local model inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 <class 'torch.Tensor'>\n",
      "2 <class 'torch.Tensor'>\n",
      "3 <class 'torch.Tensor'>\n",
      "Prediction: 8 Ground Truth: 8\n",
      "tensor([1.9781e-05, 1.2030e-05, 3.8570e-05, 1.6585e-05, 2.9105e-06, 4.2355e-04,\n",
      "        7.8849e-06, 2.1061e-06, 9.9940e-01, 7.7318e-05],\n",
      "       grad_fn=<SqueezeBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# classify local\n",
    "sy.logger.trace(\"Before running classify\")\n",
    "class_num, preds = classify_local(image_1)\n",
    "print(f\"Prediction: {class_num} Ground Truth: {label_1}\")\n",
    "print(preds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Remote model inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "> Waiting for Blocking Request\n",
      "inference: To see a real world example of inference.\n",
      "<UID:01e016e3-3172-48e9-839e-15ab9d71a293>\n",
      "\n",
      "> INSIDE Request BLOCK 1.1920928955078125e-06 seconds False\n",
      "\n",
      "> Blocking Request ACCEPTED\n",
      "Prediction: 8 Ground Truth: 8\n",
      "tensor([8.4406e-06, 1.8314e-05, 2.1068e-06, 3.2389e-06, 5.2382e-05, 5.4962e-05,\n",
      "        1.4191e-06, 1.9418e-08, 9.9982e-01, 3.6174e-05],\n",
      "       grad_fn=<SqueezeBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# classify remote\n",
    "sy.logger.trace(\"Before running classify\")\n",
    "class_num, preds = classify(image_1)\n",
    "print(f\"Prediction: {class_num} Ground Truth: {label_1}\")\n",
    "print(preds)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
